83cc69
@@ -25,13 +25,11 @@
 
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
-import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
@@ -60,22 +58,19 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
     GenericUDFUtils.ReturnObjectInspectorResolver returnOIResolver;
     returnOIResolver = new GenericUDFUtils.ReturnObjectInspectorResolver(true);
 
-    if (arguments.length != 1) {
-      throw new UDFArgumentLengthException(
-        "The function SORT_ARRAY(array(obj1, obj2,...)) needs one argument.");
-    }
+    checkArgsSize(arguments, 1, 1);
 
     switch(arguments[0].getCategory()) {
       case LIST:
-        if(((ListObjectInspector)(arguments[0])).getListElementObjectInspector()
-          .getCategory().equals(Category.PRIMITIVE)) {
+        if(!((ListObjectInspector)(arguments[0])).getListElementObjectInspector()
+            .getCategory().equals(ObjectInspector.Category.UNION)) {
           break;
         }
       default:
         throw new UDFArgumentTypeException(0, "Argument 1"
-          + " of function SORT_ARRAY must be " + serdeConstants.LIST_TYPE_NAME
-          + "<" + Category.PRIMITIVE + ">, but " + arguments[0].getTypeName()
-          + " was found.");
+            + " of function SORT_ARRAY must be " + serdeConstants.LIST_TYPE_NAME
+            + ", and element type should be either primitive, list, struct, or map, " +
+            "but " + arguments[0].getTypeName() + " was found.");
     }
 
     ObjectInspector elementObjectInspector =
